library(SingleCellExperiment)
## Loading required package: SummarizedExperiment
## Loading required package: MatrixGenerics
## Loading required package: matrixStats
## 
## Attaching package: 'MatrixGenerics'
## The following objects are masked from 'package:matrixStats':
## 
##     colAlls, colAnyNAs, colAnys, colAvgsPerRowSet, colCollapse,
##     colCounts, colCummaxs, colCummins, colCumprods, colCumsums,
##     colDiffs, colIQRDiffs, colIQRs, colLogSumExps, colMadDiffs,
##     colMads, colMaxs, colMeans2, colMedians, colMins, colOrderStats,
##     colProds, colQuantiles, colRanges, colRanks, colSdDiffs, colSds,
##     colSums2, colTabulates, colVarDiffs, colVars, colWeightedMads,
##     colWeightedMeans, colWeightedMedians, colWeightedSds,
##     colWeightedVars, rowAlls, rowAnyNAs, rowAnys, rowAvgsPerColSet,
##     rowCollapse, rowCounts, rowCummaxs, rowCummins, rowCumprods,
##     rowCumsums, rowDiffs, rowIQRDiffs, rowIQRs, rowLogSumExps,
##     rowMadDiffs, rowMads, rowMaxs, rowMeans2, rowMedians, rowMins,
##     rowOrderStats, rowProds, rowQuantiles, rowRanges, rowRanks,
##     rowSdDiffs, rowSds, rowSums2, rowTabulates, rowVarDiffs, rowVars,
##     rowWeightedMads, rowWeightedMeans, rowWeightedMedians,
##     rowWeightedSds, rowWeightedVars
## Loading required package: GenomicRanges
## Loading required package: stats4
## Loading required package: BiocGenerics
## 
## Attaching package: 'BiocGenerics'
## The following objects are masked from 'package:stats':
## 
##     IQR, mad, sd, var, xtabs
## The following objects are masked from 'package:base':
## 
##     anyDuplicated, aperm, append, as.data.frame, basename, cbind,
##     colnames, dirname, do.call, duplicated, eval, evalq, Filter, Find,
##     get, grep, grepl, intersect, is.unsorted, lapply, Map, mapply,
##     match, mget, order, paste, pmax, pmax.int, pmin, pmin.int,
##     Position, rank, rbind, Reduce, rownames, sapply, setdiff, sort,
##     table, tapply, union, unique, unsplit, which.max, which.min
## Loading required package: S4Vectors
## 
## Attaching package: 'S4Vectors'
## The following object is masked from 'package:utils':
## 
##     findMatches
## The following objects are masked from 'package:base':
## 
##     expand.grid, I, unname
## Loading required package: IRanges
## Loading required package: GenomeInfoDb
## Loading required package: Biobase
## Welcome to Bioconductor
## 
##     Vignettes contain introductory material; view with
##     'browseVignettes()'. To cite Bioconductor, see
##     'citation("Biobase")', and for packages 'citation("pkgname")'.
## 
## Attaching package: 'Biobase'
## The following object is masked from 'package:MatrixGenerics':
## 
##     rowMedians
## The following objects are masked from 'package:matrixStats':
## 
##     anyMissing, rowMedians
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.3     ✔ readr     2.1.4
## ✔ forcats   1.0.0     ✔ stringr   1.5.0
## ✔ ggplot2   3.4.3     ✔ tibble    3.2.1
## ✔ lubridate 1.9.3     ✔ tidyr     1.3.0
## ✔ purrr     1.0.2
## ── Conflicts ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ lubridate::%within%() masks IRanges::%within%()
## ✖ dplyr::collapse()     masks IRanges::collapse()
## ✖ dplyr::combine()      masks Biobase::combine(), BiocGenerics::combine()
## ✖ dplyr::count()        masks matrixStats::count()
## ✖ dplyr::desc()         masks IRanges::desc()
## ✖ tidyr::expand()       masks S4Vectors::expand()
## ✖ dplyr::filter()       masks stats::filter()
## ✖ dplyr::first()        masks S4Vectors::first()
## ✖ dplyr::lag()          masks stats::lag()
## ✖ ggplot2::Position()   masks BiocGenerics::Position(), base::Position()
## ✖ purrr::reduce()       masks GenomicRanges::reduce(), IRanges::reduce()
## ✖ dplyr::rename()       masks S4Vectors::rename()
## ✖ lubridate::second()   masks S4Vectors::second()
## ✖ lubridate::second<-() masks S4Vectors::second<-()
## ✖ dplyr::slice()        masks IRanges::slice()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(glue)
## 
## Attaching package: 'glue'
## 
## The following object is masked from 'package:SummarizedExperiment':
## 
##     trim
## 
## The following object is masked from 'package:GenomicRanges':
## 
##     trim
## 
## The following object is masked from 'package:IRanges':
## 
##     trim
library(lemur)
## 
## Attaching package: 'lemur'
## 
## The following object is masked from 'package:dplyr':
## 
##     vars
## 
## The following object is masked from 'package:ggplot2':
## 
##     vars
source("util.R")
# Benchmark variation; interpolated into the result-file paths read below.
variation <- "random_holdout_hvg"

# Display names for the datasets, keyed by the identifier used in the
# benchmark output files (the `data` column).
dataset_labels <- c(
  angelidis = "Angelidis",
  aztekin = "Aztekin",
  bunis = "Bunis",
  goldfarbmuren = "Goldfarbmuren",
  hrvatin = "Hrvatin",
  jakel = "Jakel",
  sathyamurthy = "Sathyamurthy",
  kang = "Kang",
  skinnider = "Skinnider",
  bhattacherjee = "Bhattacherjee",
  canogamez = "Canogamez",
  reyfman = "Reyfman",
  mouse_gastrulation = "Pijuan-Sala"
)

# Display names for the integration methods, keyed by the `method` column.
method_labels <- c(
  PCA = "PCA",
  harmony = "Harmony",
  invertible_harmony = "Param. Harmony",
  multiCondPCA = "Rigid LEMUR",
  lemur = "LEMUR",
  CPA = "CPA"
)

# Integration ----

# Integration benchmark results for the selected `variation`.
# Of the CPA runs, "CPA" and "CPA_large" are dropped and "CPA_kangparams" is
# relabelled as plain "CPA" (presumably the run with Kang-tuned
# hyperparameters — confirm).
int_data <- read_tsv(glue("../benchmark/output/integration_results-{variation}.tsv")) %>%
  filter(method != "CPA" & method != "CPA_large") %>%
  mutate(method = ifelse(method == "CPA_kangparams", "CPA", method))
## Rows: 169 Columns: 17
## ── Column specification ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
## Delimiter: "\t"
## chr  (4): data, method, status, comparison
## dbl (12): mmd, wasserstein, mix, ARI, AMI, NMI, celltypeSS_over_totalSS, bat...
## lgl  (1): job
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Print the filtered integration results for inspection.
int_data
## # A tibble: 143 × 17
##    data      method     job   status     mmd wasserstein   mix   ARI   AMI   NMI
##    <chr>     <chr>      <lgl> <chr>    <dbl>       <dbl> <dbl> <dbl> <dbl> <dbl>
##  1 angelidis CPA        NA    done   0.00570       1.03   8.64 0.435 0.651 0.658
##  2 angelidis CPA        NA    done   0.00745       1.02   8.96 0.407 0.614 0.635
##  3 angelidis harmony    NA    done   0.0215        0.898  5.27 0.667 0.835 0.838
##  4 angelidis PCA        NA    done   0.0422        0.992  2.67 0.668 0.837 0.840
##  5 angelidis PCA        NA    done   0.0486        0.953  2.86 0.684 0.833 0.842
##  6 angelidis lemur      NA    done   0.0188        0.885  4.49 0.793 0.880 0.883
##  7 angelidis lemur      NA    done   0.0239        0.849  4.83 0.748 0.862 0.870
##  8 angelidis invertibl… NA    done   0.0210        0.890  5.08 0.664 0.838 0.841
##  9 angelidis invertibl… NA    done   0.0264        0.850  5.48 0.719 0.847 0.856
## 10 angelidis multiCond… NA    done   0.0147        0.916  1.29 0.858 0.944 0.945
## # ℹ 133 more rows
## # ℹ 7 more variables: celltypeSS_over_totalSS <dbl>,
## #   batchSS_over_totalSS <dbl>, comparison <chr>, `stat-elapsed` <dbl>,
## #   `stat-user` <dbl>, `stat-sys` <dbl>, `stat-max_mem_kbytes` <dbl>
# Method order used for the x axis (factor levels) of the integration plots.
methods <- c("PCA", "harmony", "invertible_harmony", "multiCondPCA", "lemur", "CPA")
# Sanity check: methods present in the results but missing from `methods`
# (character(0) means none).
setdiff(unique(int_data$method), methods)
## character(0)
# Integration metrics as quoted column references, spliced into select() below.
# NOTE(review): `vars` is masked by lemur (see the attach message above) —
# confirm lemur::vars behaves like dplyr::vars here.
int_metrics <- vars(mmd, wasserstein, mix, ARI, AMI, NMI, celltypeSS_over_totalSS, batchSS_over_totalSS)

# Long format: each result row expands to one row per metric
# (columns data, method, comparison, metric, value). MMD is log-transformed first.
perf_data <- int_data %>%
  mutate(mmd = log(mmd)) %>%
  dplyr::select(data, method, comparison, !!! int_metrics) %>%
  pivot_longer(all_of(map_chr(int_metrics, rlang::as_name)), names_to = "metric") 
# Legend labels for the `comparison` column.
condition_labels <- c("holdout_vs_train" = "Holdout", "train_vs_train" = "Training")

# Panel (B): biological signal retention. Shows ARI, AMI, NMI and the fraction of
# variance explained by cell type for each method. Facet rows are metrics, facet
# columns are datasets, and points are coloured by comparison.
suppl_int_signal_plt <- perf_data %>%
  mutate(method = factor(method, levels = methods)) %>%
  filter(metric %in% c("ARI",  "AMI", "NMI", "celltypeSS_over_totalSS")) %>%
  # Fix the facet row order and attach display labels to the metrics.
  mutate(metric = factor(metric, levels = c("ARI",  "AMI", "NMI", "celltypeSS_over_totalSS"),
                         labels = c("ARI", "AMI", "NMI", "Var. Expl.\nCell type"))) %>%
  ggplot(aes(x = method, y = value)) +
    geom_point(aes(color = comparison), size = 0.4) +
    # Free y scale per metric row; dataset strips use display names.
    # clip = "off" presumably keeps long strip labels from being cut off.
    ggh4x::facet_grid2(vars(metric), vars(data), scales = "free_y", labeller = labeller(data = as_labeller(dataset_labels)), 
               strip = ggh4x::strip_vanilla(clip = "off")) +
    scale_x_discrete(labels = method_labels) +
    scale_color_discrete(labels = condition_labels) +
    guides(x = guide_axis(angle = 90)) +
    theme(panel.grid.major.y = element_line(colour = "grey50"), axis.title = element_blank(),
          panel.spacing.y = unit(5, "mm")) +
    labs(title = "(B) Biological signal retention")

# Panel (A): integration of conditions. Shows kNN mixing, log MMD, Wasserstein
# distance and the fraction of variance explained by condition for each method.
# Facet rows are metrics, facet columns are datasets.
suppl_int_mix_plt <- perf_data %>%
  mutate(method = factor(method, levels = methods)) %>%
  filter(metric %in% c("mmd", "wasserstein", "mix", "batchSS_over_totalSS")) %>%
  # Facet row order is mix, MMD, Wasserstein, variance explained; labels are
  # LaTeX-ready (the figure is rendered with latex_support below).
  mutate(metric = factor(metric, levels = c("mix", "mmd", "wasserstein", "batchSS_over_totalSS"),
                         labels = c("kNN Mix", "$\\log$ MMD", "Wasserstein",  "Var. Expl.\nCondition"))) %>%
  ggplot(aes(x = method, y = value)) +
    geom_point(aes(color = comparison), size = 0.4) +
    ggh4x::facet_grid2(vars(metric), vars(data), scales = "free_y", labeller = labeller(data = as_labeller(dataset_labels)), 
               strip = ggh4x::strip_vanilla(clip = "off")) +
    scale_x_discrete(labels = method_labels) +
    scale_color_discrete(labels = condition_labels) +
    guides(x = guide_axis(angle = 90)) +
    theme(panel.grid.major.y = element_line(colour = "grey50"), axis.title = element_blank(),
          panel.spacing.y = unit(5, "mm")) +
    labs(title = "(A) Integration of conditions")

# Panel (C): for each dataset and method, the ratio of variance explained by
# condition to variance explained by cell type, built from the two SS metrics.
suppl_int_ss_plt <- perf_data %>%
  mutate(method = factor(method, levels = methods)) %>%
  filter(metric %in% c("celltypeSS_over_totalSS", "batchSS_over_totalSS")) %>%
  # Back to wide format so both SS fractions sit side by side in one row
  # per (data, method, comparison).
  pivot_wider(id_cols = c(data, method, comparison), names_from = metric, values_from = value) %>%
  mutate(rel_SS = batchSS_over_totalSS / celltypeSS_over_totalSS) %>%
  ggplot(aes(x = method, y = rel_SS)) +
    geom_point(aes(color = comparison), size = 0.4) +
    # One row of panels with a shared y scale across datasets.
    ggh4x::facet_wrap2(vars(data), scales = "fixed", labeller = labeller(data = as_labeller(dataset_labels)),
                       strip = ggh4x::strip_vanilla(clip = "off"), nrow = 1) +
    scale_x_discrete(labels = method_labels) +
    scale_color_discrete(labels = condition_labels) +
    guides(x = guide_axis(angle = 90)) +
    theme(panel.grid.major.y = element_line(colour = "grey50"), axis.title = element_blank()) +
    labs(title = "(C) Ratio of variance explained $\\frac{\\text{condition}}{\\text{cell type}}$")


# Assemble the integration supplement figure. Panels A, B and C are stacked
# vertically with their own colour legends removed. A single legend, taken from
# panel C, is placed at the bottom. The figure is written to PDF with LaTeX
# label support (plot_assemble / add_plot come from util.R; coordinates are
# presumably in `units` = mm).
plot_assemble(
  add_plot(suppl_int_mix_plt  + guides(color = "none"), x = 0, y = 0, width = 200, height = 80),
  add_plot(suppl_int_signal_plt + guides(color = "none"), x = 0, y = 80, width = 200, height = 80),
  add_plot(suppl_int_ss_plt + guides(color = "none"), x = 0, y = 160, width = 200, height = 40),
  add_plot(cowplot::get_legend(suppl_int_ss_plt + labs(color = "") + theme(legend.position = "bottom")), x = 0, y = 200, height = 10, width = 50),

  
  width = 200, height = 210, units = "mm", show_grid_lines = FALSE,
  latex_support = TRUE, filename = "../plots/suppl_integration_all_metrics.pdf"
)
## gg[gg1]
## gg[gg2]
## gg[gg3]
## gg[gg4]
# Nonparametric bootstrap over the rows of `data`.
#
# data:         a matrix, data.frame or tibble; rows are the resampling units.
# FUN:          function applied to each resample (same kind of object as
#               `data`); its return value is stored unchanged in `value`.
# n_iterations: number of bootstrap resamples.
# map_fun:      function used to iterate over the resample indices, called as
#               map_fun(seq_len(n_iterations), f); defaults to purrr::map.
#
# Returns a tibble with `bootstrap_iteration` (1..n_iterations) and `value`
# (whatever map_fun returns, typically a list of FUN results).
bootstrap <- function(data, FUN, n_iterations = 1000, map_fun = purrr::map){
  n <- nrow(data)
  tibble(bootstrap_iteration = seq_len(n_iterations),
         value = map_fun(seq_len(n_iterations), \(idx){
            sel <- sample.int(n, size = n, replace = TRUE)
            # drop = FALSE keeps single-column matrices / data.frames
            # two-dimensional, so FUN never receives a bare vector.
            FUN(data[sel, , drop = FALSE])
          }))
}

# Quick visual sanity check of bootstrap() on the built-in `faithful` data.
# Eruptions are split into short (< 3) and long groups. The mean
# (eruptions, waiting) is bootstrapped 100 times per group, and the density of
# the bootstrap means is drawn over the raw observations.
# NOTE(review): no set.seed() before this call, so the result is not reproducible.
faithful %>%
  as_tibble() %>%
  mutate(group = eruptions < 3) %>%
  # Each FUN result is a 1 x 2 matrix of column means; unnest() stacks them
  # into a matrix column `value` (col 1 = eruptions, col 2 = waiting).
  reframe(means = bootstrap(cbind(eruptions, waiting), \(mat) matrix(colMeans(mat), nrow = 1), n_iterations = 100),
          .by = group) %>%
  unpack(means) %>%
  unnest(value) %>%
  ggplot(aes(x = value[,1], y = value[,2])) +
    geom_density2d(aes(color = group)) +
    geom_point(data = faithful, aes(x = eruptions, y = waiting))

# Pair each result's kNN mixing score with its ARI, then keep four methods.
# Harmony rows are kept for every comparison; other methods only for
# "holdout_vs_train" (presumably harmony is reported under a different
# comparison label — confirm). avg_mix / avg_ARI are per-dataset means that
# later normalize the scores across datasets.
# NOTE(review): mean() has no na.rm, so one NA makes the dataset's average NA.
int_bio_df <- left_join(int_data %>%  dplyr::select(data, method, comparison, mix),
          perf_data %>%
            filter(metric == "ARI") %>%
            dplyr::rename(ARI = value) %>%
            dplyr::select(-metric), by = c("data", "method", "comparison")) %>%
  filter(comparison == "holdout_vs_train" | method == "harmony") %>%
  filter(method %in% c("PCA", "harmony", "lemur", "CPA")) %>%
  mutate(method = factor(method, c("lemur", "PCA", "harmony", "CPA"))) %>%
  mutate(avg_mix = mean(mix),
         avg_ARI = mean(ARI), .by = c(data)) 
  
# Fix the RNG so the bootstrap resamples below are reproducible.
set.seed(1)
# Integration vs. biological-signal trade-off. For each method (and comparison),
# the mean relative kNN mixing and relative ARI across datasets are bootstrapped
# 1000 times. The plot shows density contours of those bootstrap means.
int_bio_pl <- int_bio_df %>%
  mutate(method = as.character(method)) %>%
  reframe(bootstrap_samples = bootstrap(cbind(mix / avg_mix, ARI / avg_ARI), \(mat) matrix(colMeans(mat), nrow = 1), n_iterations = 1000),
        .by = c(method, comparison)) %>%
  unpack(bootstrap_samples) %>%
  unnest(value) %>%
  # value[, 1] = relative mixing, value[, 2] = relative ARI.
  ggplot(aes(x = value[,1], y = value[,2])) +
    # Contours at 0.9, 0.55 and 0.2 of each method's maximum density.
    geom_density2d(aes(color = method), adjust = 3, contour_var = "ndensity", breaks = seq(0.9, 0.2, length.out = 3)) +
    # geom_point(data = int_bio_df, aes(x =  mix / avg_mix, y = ARI / avg_ARI, color = stage(method, after_scale = colorspace::lighten(colour, 0.3))), size = 0.1) +
    # Method labels at the average bootstrap mean of each method.
    ggrepel::geom_text_repel(data = . %>% summarize(value = matrix(colMeans(value), nrow = 1), .by = method), aes(label = method_labels[method]), 
                             size = font_size_small / .pt, point.padding = unit(3, "mm") ) +
    # Small arrow in the top-right corner of the panel (npc coordinates).
    annotation_custom(grid::polylineGrob(x = c(0.91, 0.99), y = c(0.91, 0.99), gp = grid::gpar(fill = "black"), arrow =  grid::arrow(ends = "last", type = "closed", angle = 20, length = unit(10 / 7, "mm")))) +
    labs(y = "Biological signal retention (Relative ARI)",x = "Integration of conditions (Relative $k$-NN mixing)") +
    # NOTE(review): break 2 lies outside limits c(0, 1.7) and is not drawn.
    scale_x_continuous(limits = c(0, 1.7), breaks = c(0, 1, 2), expand = expansion(add = c(0, 0.05))) +
    scale_y_continuous(limits = c(0, 1.3), breaks = c(0, 1), expand = expansion(add = c(0, 0.05))) +
    guides(color = "none") +
    coord_fixed() 

int_bio_pl

# Per-cell 2D embeddings (UMAP / t-SNE / PCA columns) of the Kang data for
# each method. The path is built from `variation`, like the other result
# files, instead of hard-coding a copy of its value.
kang_vis <- read_tsv(glue("../benchmark/output/kang_visualization-{variation}.tsv")) %>%
  # Same CPA handling as the other result tables: drop "CPA" / "CPA_large"
  # and relabel "CPA_kangparams" as "CPA".
  filter(method != "CPA" & method != "CPA_large") %>%
  mutate(method = ifelse(method == "CPA_kangparams", "CPA", method))
## Rows: 167745 Columns: 15
## ── Column specification ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
## Delimiter: "\t"
## chr (7): data, method, status, name, covariate, sample, celltype
## dbl (6): umap1, umap2, tsne1, tsne2, pca1, pca2
## lgl (2): job, is_holdout
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Print a preview of the visualization table.
kang_vis
## # A tibble: 143,072 × 15
##    data  method job   status name   covariate sample celltype is_holdout   umap1
##    <chr> <chr>  <lgl> <chr>  <chr>  <chr>     <chr>  <chr>    <lgl>        <dbl>
##  1 kang  CPA    NA    done   AAACA… ctrl      patie… CD14+ M… FALSE      -0.202 
##  2 kang  CPA    NA    done   AAACA… ctrl      patie… CD14+ M… FALSE       2.21  
##  3 kang  CPA    NA    done   AAACA… ctrl      patie… CD4 T c… FALSE       0.644 
##  4 kang  CPA    NA    done   AAACA… ctrl      patie… CD14+ M… FALSE       1.04  
##  5 kang  CPA    NA    done   AAACA… ctrl      patie… Dendrit… FALSE       2.78  
##  6 kang  CPA    NA    done   AAACA… ctrl      patie… CD4 T c… FALSE       0.0801
##  7 kang  CPA    NA    done   AAACA… ctrl      patie… CD14+ M… FALSE      -0.410 
##  8 kang  CPA    NA    done   AAACA… ctrl      patie… CD4 T c… FALSE       1.84  
##  9 kang  CPA    NA    done   AAACA… ctrl      patie… CD14+ M… FALSE       0.414 
## 10 kang  CPA    NA    done   AAACA… ctrl      patie… CD4 T c… FALSE      -0.478 
## # ℹ 143,062 more rows
## # ℹ 5 more variables: umap2 <dbl>, tsne1 <dbl>, tsne2 <dbl>, pca1 <dbl>,
## #   pca2 <dbl>
# Per-method ARI and kNN-mixing values for the Kang dataset, used to annotate
# the facets of kang_plot. Uses the same method and comparison selection as
# int_bio_df. pivot_wider turns the two metrics into columns `ARI` and `mix`.
kang_umap_annot_df <- perf_data %>%
  filter(data == "kang") %>%
  filter(metric %in% c("ARI", "mix")) %>%
  filter(comparison == "holdout_vs_train" | method == "harmony") %>%
  filter(method %in% c("PCA", "harmony", "lemur", "CPA")) %>%
  mutate(method = factor(method, c("lemur", "PCA", "harmony", "CPA"))) %>%
  pivot_wider(names_from = "metric", values_from = "value")

# Panel (A): UMAP of each method's latent space for the Kang data, coloured by
# condition. Each facet is annotated with that method's kNN mixing and ARI.
kang_plot <- kang_vis %>%
  filter(method %in% c("PCA", "harmony", "lemur", "CPA")) %>%
  mutate(method = factor(method, c("lemur", "PCA", "harmony", "CPA"))) %>%
  # Treat missing is_holdout as FALSE.
  mutate(is_holdout = ! is.na(is_holdout) & is_holdout) %>%
  # Shuffle row order, presumably so neither condition is always drawn on
  # top. sample_frac()'s argument is `size`; the previous `n = 1` was
  # silently ignored (and only worked because size defaults to 1).
  sample_frac(size = 1) %>%
  ggplot(aes(x = umap1, y = umap2)) +
    ggrastr::rasterize(geom_point(aes(color = covariate), size = 0.05, stroke = 0), dpi = 600) +
    # Two stacked labels at the top centre of each facet (y = Inf, vjust > 1).
    geom_text(data = kang_umap_annot_df, aes(x = 0, y = Inf, label = glue("$k\\textrm{{NN}} = {sprintf('%.2f', mix)}$")), 
          size = font_size_small / .pt, hjust = 0.5, vjust = 1.5) +
    geom_text(data = kang_umap_annot_df, aes(x = 0, y = Inf, label = glue("$\\textrm{{ARI}} = {sprintf('%.2f', ARI)}$")), 
            size = font_size_small / .pt, hjust = 0.5, vjust = 3) +
    scale_color_manual(values = c("ctrl" = "#FC8D62", "stim" = "#8DA0CB")) +
    facet_wrap(vars(method), nrow = 1, labeller = as_labeller(c("lemur" = "LEMUR", "CPA" = "CPA", "PCA" = "PCA", "harmony" = "Harmony"))) +
    guides(color = guide_legend(override.aes = list(size = 2))) +
    small_axis(label = "UMAP", fontsize = font_size_small) +
    labs(color = "Cell condition",
         title = "(A) Latent space ($Z$) of Lupus patient samples treated with IFN-$\\beta$ (Kang)") +
    theme(legend.position = "bottom")

kang_plot

# Panel (C): the LEMUR latent-space UMAP coloured by cell type (no legend).
kang_plot_cell_type <- kang_vis %>%
  filter(method == "lemur") %>%
  # Shuffle row order to avoid systematic overplotting. The argument is
  # `size`; the previous `n = 1` was silently ignored.
  sample_frac(size = 1) %>%
  ggplot(aes(x = umap1, y = umap2)) +
    ggrastr::rasterize(geom_point(aes(color = celltype), size = 0.05, stroke = 0), dpi = 600) +
    guides(color = "none") +
    small_axis(label = "UMAP", fontsize = font_size_small) +
    labs(title = "(C) $Z$ colored by\ncell type")

kang_plot_cell_type

# Prediction ----

# Prediction benchmark results for the selected `variation`. CPA handling is
# the same as for the integration results: "CPA" and "CPA_large" are dropped
# and "CPA_kangparams" is relabelled as "CPA".
pred_data <- read_tsv(glue("../benchmark/output/prediction_results-{variation}.tsv")) %>%
  filter(method != "CPA" & method != "CPA_large") %>%
  mutate(method = ifelse(method == "CPA_kangparams", "CPA", method))
## Rows: 416 Columns: 19
## ── Column specification ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
## Delimiter: "\t"
## chr  (4): data, method, status, comparison
## dbl (14): l2_mean, l2_sd, r2_mean, r2_sd, l2_mean_per_celltype, l2_sd_per_ce...
## lgl  (1): job
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Method order (factor levels) for the prediction plots.
pred_methods <- c(
  "linear", "PCA", "harmony", "invertible_harmony",
  "multiCondPCA", "lemur", "CPA", "no_change"
)

# Display names for the prediction methods, keyed by the `method` column.
pred_method_labels <- c(
  linear = "Linear",
  PCA = "PCA",
  harmony = "Harmony",
  invertible_harmony = "Param. Harmony",
  multiCondPCA = "Rigid LEMUR",
  lemur = "LEMUR",
  CPA = "CPA",
  no_change = "Identity"
)

# Metric order (facet rows) for the supplementary prediction figure.
pred_metrics <- c(
  "l2_mean", "l2_mean_per_celltype", "l2_sd", "l2_sd_per_celltype",
  "r2_mean", "r2_mean_per_celltype", "r2_sd", "r2_sd_per_celltype",
  "mmd", "wasserstein"
)

# LaTeX-ready facet labels for the metrics; "\n" breaks the strip text.
pred_metrics_labels <- c(
  l2_mean = "$L_2$ mean",
  l2_mean_per_celltype = "$L_2$ mean\ncell type",
  l2_sd = "$L_2$ of S.D.",
  l2_sd_per_celltype = "$L_2$ of S.D.\ncell type",
  mmd = "$\\log$ MMD",
  r2_mean = "$R^2$ mean",
  r2_mean_per_celltype = "$R^2$ mean\ncell type",
  r2_sd = "$R^2$ of S.D.",
  r2_sd_per_celltype = "$R^2$ of S.D.\ncell type",
  wasserstein = "Wasserstein\ndistance"
)
# Long-format prediction metrics for the supplementary figure. Timing and
# memory columns (`stat-*`) are dropped, and MMD is log-transformed.
# `is_perf` classifies each comparison as holdout or training by its prefix.
# NOTE(review): factor(levels = ...) maps any method/metric absent from
# pred_methods / pred_metrics to NA — confirm nothing is lost.
supl_pred_df <- pred_data %>%
  dplyr::select(-starts_with("stat-")) %>%
  mutate(mmd = log(mmd)) %>%
  pivot_longer(-c(method, job, data, status, comparison), names_to = "metric") %>%
  mutate(method = factor(method, levels = pred_methods)) %>%
  mutate(metric = factor(metric, levels = pred_metrics)) %>%
  mutate(is_perf = ifelse(str_starts(comparison, "holdout"), "holdout", "training")) 
  
# Build one half of the supplementary prediction figure.
# Each (metric, dataset) pair gets its own panel with an independent y axis.
# Within a panel, every method's results are drawn as a beeswarm coloured by
# `is_perf`.
#
# data: long table with the columns method, value, is_perf, metric and data.
# Returns a ggplot object.
make_supl_pred_plot <- function(data){
  strip_labels <- labeller(metric = as_labeller(pred_metrics_labels),
                           data = as_labeller(dataset_labels))
  panel_theme <- theme(panel.grid.major.y = element_line(colour = "grey50"),
                       axis.title = element_blank(),
                       panel.spacing.y = unit(5, "mm"),
                       strip.text.y = element_text(size = font_size_tiny))

  base_plot <- ggplot(data, aes(x = method, y = value))
  base_plot +
    ggbeeswarm::geom_quasirandom(aes(color = is_perf), width = 0.2, size = 0.4) +
    # independent = "y": each panel gets its own y range, not only each row.
    ggh4x::facet_grid2(vars(metric), vars(data), scales = "free_y", independent = "y",
                       labeller = strip_labels,
                       strip = ggh4x::strip_vanilla(clip = "off")) +
    scale_x_discrete(labels = pred_method_labels) +
    scale_y_continuous(breaks = scales::breaks_pretty(n = 3)) +
    guides(x = guide_axis(angle = 90)) +
    labs(color = "") +
    panel_theme
}

# Split the datasets into two figure halves by their alphabetical factor
# index: the first 7 in pl1, the remaining ones in pl2.
pl1 <- supl_pred_df %>%
  filter(as.integer(as.factor(data)) <= 7) %>%
  make_supl_pred_plot()

pl2 <- supl_pred_df %>%
  filter(as.integer(as.factor(data)) > 7) %>%
  make_supl_pred_plot()


# Stack both halves into one PDF. Only pl2 keeps its colour legend.
# NOTE(review): pl2 is 160 mm wide vs. 170 mm for pl1 — confirm intentional.
plot_assemble(
  add_plot(pl1 + guides(color = "none"), x = 0, y = 0, width = 170, height = 120),
  add_plot(pl2 , x = 0, y = 125, width = 160, height = 120),
  
  width = 170, height = 245, units = "mm", show_grid_lines = FALSE,
  latex_support = TRUE, filename = "../plots/suppl_prediction_all_metrics.pdf"
)
## Warning: Removed 245 rows containing missing values (`position_quasirandom()`).
## gg[gg1]
## Warning: Removed 48 rows containing missing values (`position_quasirandom()`).
## Warning in regularize.values(x, y, ties, missing(ties), na.rm = na.rm):
## collapsing to unique 'x' values

## Warning in regularize.values(x, y, ties, missing(ties), na.rm = na.rm):
## collapsing to unique 'x' values
## gg[gg2]
# Long format of the raw (untransformed) prediction metrics. Each input row
# becomes one row per metric column (mmd, wasserstein, l2*, r2*), keyed by
# data / method / comparison.
pred_data_long <- pred_data %>%
  dplyr::select(data, method, comparison, mmd, wasserstein,
                starts_with("l2"), starts_with("r2")) %>%
  pivot_longer(!c(data, method, comparison), names_to = "metric")
# Panel (E): prediction error of LEMUR, Identity and CPA on the holdout
# comparisons. Each value is divided by the mean over the plotted methods and
# comparisons within the same dataset and metric (1 = average; lower = smaller
# error). Grey dots are individual results; red point-ranges are mean ± SE.
# NOTE(review): mean(value) has no na.rm, so a single NA makes rel_perf NA for
# its whole (data, metric) group. The 12 rows dropped by stat_summary in the
# output below may be such a group — confirm.
perf_pred_plot <- pred_data_long %>%
  filter(metric %in% c("l2_mean", "l2_mean_per_celltype")) %>%
  filter(str_starts(comparison, "holdout[12]_")) %>%
  filter(method %in% c("no_change", "lemur", "CPA")) %>%
  mutate(method = factor(method, c("lemur", "no_change", "CPA"))) %>%
  # Reverse the levels so LEMUR ends up at the top of the y axis.
  mutate(method = fct_rev(method)) %>%
  mutate(metric = factor(metric, levels = c("l2_mean", "l2_mean_per_celltype"))) %>%
  mutate(mean_value = mean(value), .by = c(data, metric)) %>%
  mutate(rel_perf = value / mean_value) %>%
  ggplot(aes(x = rel_perf, y = method)) +
    geom_vline(xintercept = 1, linewidth = 0.3, color = "black") +
    ggbeeswarm::geom_quasirandom(color = "grey", width = 0.2, size = 0.5) +
    stat_summary(geom = "pointrange", fun.data = mean_se, color = "red", size = 0.4) +
    small_arrow(position = c(0.1, 0.01), offset = 0.03) +
    facet_wrap(vars(metric), nrow = 2,  labeller = as_labeller(c("l2_mean" = "Overall error\n($L_2$ distance)", 
                                                                 "l2_mean_per_celltype" = "Error per cell type\n(Mean $L_2$ distance)"))) +
    scale_x_continuous(expand = expansion(mult = c(0, 0.1)), breaks = c(0, 1, 2)) +
    scale_y_discrete(labels = c("lemur" = "LEMUR", "CPA" = "CPA", "no_change" = "Identity")) +
    # Zoom to [0, 2] without dropping the data outside the range.
    coord_cartesian(xlim = c(0, 2)) +
    theme(axis.title.y = element_blank(), plot.title.position = "plot") +
    labs(title = "(E) Prediction performance", x = "Relative performance across 13 datasets")

perf_pred_plot
## Warning: Removed 12 rows containing non-finite values (`stat_summary()`).
## Orientation inferred to be along y-axis; override with
## `position_quasirandom(orientation = 'x')`
## Warning: Removed 12 rows containing missing values (`position_quasirandom()`).

# Detailed Kang predictions: one column per "<condition>-<cell type>" with
# the expression of each gene (rows) per method. The path is built from
# `variation` instead of hard-coding a copy of its value.
kang_pred <- read_tsv(glue("../benchmark/output/kang-detailed_prediction_results_{variation}.tsv")) %>%
  # Split "stim-NK cells" style column names into condition and cell_type.
  pivot_longer(c(starts_with("stim"), starts_with("ctrl")), names_sep = "-", names_to = c("condition", "cell_type"), values_to = "expression") %>%
  # Same CPA handling as the other result tables.
  filter(method != "CPA" & method != "CPA_large") %>%
  mutate(method = ifelse(method == "CPA_kangparams", "CPA", method))
## Rows: 4500 Columns: 18
## ── Column specification ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
## Delimiter: "\t"
## chr  (2): method, gene
## dbl (16): stim-NK cells, stim-CD4 T cells, stim-CD14+ Monocytes, stim-CD8 T ...
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Per-cell-type mean L2 error on the Kang "holdout2_vs_obs2" comparison for
# the three methods shown in ct_scatter_plot. `method` is renamed to
# `method.pred` so these rows match that plot's facet variable.
l2_mean_dist_df <- pred_data_long %>% 
  filter(data == "kang") %>%
  filter(metric == "l2_mean_per_celltype") %>% 
  filter(comparison == "holdout2_vs_obs2") %>%
  dplyr::rename(method.pred = method) %>%
  filter(method.pred %in% c("no_change", "lemur", "CPA")) %>%
  mutate(method.pred = factor(method.pred, c("lemur", "no_change", "CPA")))

# Panel (D): predicted vs. observed expression in the stimulated condition,
# per gene and cell type, for LEMUR, Identity and CPA. Each non-"obs" row is
# matched to the "obs" row with the same gene, condition and cell type.
ct_scatter_plot <- tidylog::inner_join(filter(kang_pred, method != "obs"),
           filter(kang_pred, method == "obs"), by = c("gene", "condition", "cell_type"), suffix = c(".pred", ".obs")) %>%
  filter(method.pred %in% c("no_change", "lemur", "CPA")) %>%
  filter(condition == "stim") %>%
  mutate(method.pred = factor(method.pred, c("lemur", "no_change", "CPA"))) %>%
  ggplot(aes(x = expression.obs, y = expression.pred)) +
    geom_abline() +
    ggrastr::rasterize(geom_point(aes(color = cell_type), size = 0.4, stroke = 0), dpi = 300) +
    # Per-facet L2 annotation. geom_text aligns with `hjust`; the previous
    # `halign` argument was ignored with an "unknown parameters" warning.
    geom_text(data = l2_mean_dist_df, aes(x = 12, y = 22, label = glue("$L_2 = {sprintf('%.2f', value)}$")), 
              size = font_size_small / .pt, hjust = 0.5) +
    facet_wrap(vars(method.pred), nrow = 1, labeller = as_labeller(c("lemur" = "LEMUR", "CPA" = "CPA", "no_change" = "Identity"))) +
    guides(color = guide_legend(override.aes = list(size = 2))) +
    coord_fixed() +
    labs(color = "",
         title = "(D) Predicted vs. observed gene expression per cell type",
         x = "Observed expression", y = "Predicted expression") +
    theme(legend.position = "bottom", plot.title.position =  "plot")
## inner_join: added 4 columns (method.pred, expression.pred, method.obs, expression.obs)
##             > rows only in x  (     0)
##             > rows only in y  (     0)
##             > matched rows     56,000
##             >                 ========
##             > rows total       56,000
## Warning in geom_text(data = l2_mean_dist_df, aes(x = 12, y = 22, label =
## glue("$L_2 = {sprintf('%.2f', value)}$")), : Ignoring unknown parameters:
## `halign`
# Display panel (D).
ct_scatter_plot

# Exploratory version of panel (D) for every predicting method (no method
# filter, default facet labels, no L2 annotation). It is only displayed and
# not stored.
tidylog::inner_join(filter(kang_pred, method != "obs"),
           filter(kang_pred, method == "obs"), by = c("gene", "condition", "cell_type"), suffix = c(".pred", ".obs")) %>%
  filter(condition == "stim") %>%
  ggplot(aes(x = expression.obs, y = expression.pred)) +
    geom_abline() +
    ggrastr::rasterize(geom_point(aes(color = cell_type), size = 0.4, stroke = 0), dpi = 300) +
    facet_wrap(vars(method.pred), nrow = 1) +
    guides(color = guide_legend(override.aes = list(size = 2))) +
    coord_fixed() +
    theme(legend.position = "bottom", plot.title.position =  "plot")
## inner_join: added 4 columns (method.pred, expression.pred, method.obs, expression.obs)
##             > rows only in x  (     0)
##             > rows only in y  (     0)
##             > matched rows     56,000
##             >                 ========
##             > rows total       56,000

# Neighborhood identification ----

# Precomputed LEMUR fit on the Kang data and its DE neighborhoods
# (`nei$neighborhood[[i]]` holds the cell names of gene i's neighborhood, as
# used in the plots below).
lemur_fit <- readRDS("../benchmark/output/differential-expression-kang_lemur_fit.RDS")
nei <- as_tibble(readRDS("../benchmark/output/differential-expression-kang_lemur_fit-neighborhood.RDS"))
# Seeded so the UMAP layout is reproducible.
set.seed(1)
# Transposed so cells become rows (assumes embedding is latent dims x cells —
# TODO confirm).
umap <- uwot::umap(as.matrix(t(lemur_fit$embedding)))
# Row index of the example gene; the precision/recall plot highlights the
# gene named "simulated_gene-<sel_gene>".
sel_gene <- 8

# UMAP of all cells coloured by whether the selected gene was simulated as DE
# in that cell (ground truth).
de_simulated_pl <- as_tibble(colData(lemur_fit)) %>%
  # Adds the 2-column UMAP matrix as a matrix column `umap`.
  mutate(umap) %>%
  mutate(is_de = lemur_fit$rowData$is_de_cell[[sel_gene]]) %>%
  mutate(is_de = ifelse(is_de, "DE", "Not DE")) %>%
  # Reverse the factor level order (affects legend/colour assignment).
  mutate(is_de = fct_rev(is_de)) %>%
  ggplot(aes(x = umap[,1], y = umap[,2])) +
    ggrastr::rasterize(geom_point(aes(color = is_de), size = 0.05, stroke = 0), dpi = 600) +
    small_axis(label = "UMAP", fontsize = font_size_small) +
    labs(#title = "(F) Simulated DE", 
         color = "") +
    guides(color = guide_legend(override.aes = list(size = 1))) +
    theme(legend.position = "top")

# UMAP coloured by the selected gene's logcounts, with one facet per
# (fake) condition. The colour scale runs from 0 to the gene's 98th-percentile
# expression; values above it are squished onto the top colour.
de_expr_pl <- as_tibble(colData(lemur_fit)) %>%
  mutate(umap) %>%
  mutate(expr = assay(lemur_fit, "logcounts")[sel_gene,]) %>%
  # NOTE(review): `inside` is computed but not used by this plot.
  mutate(inside = colnames(lemur_fit) %in% nei$neighborhood[[sel_gene]]) %>%
  ggplot(aes(x = umap[,1], y = umap[,2])) +
    ggrastr::rasterize(geom_point(aes(color = expr), size = 0.1, stroke = 0), dpi = 600) +
    # scale_color_viridis_c() +
    colorspace::scale_color_continuous_sequential(limits = c(0, quantile(assay(lemur_fit, "logcounts")[sel_gene,], c(0.98))), oob = scales::oob_squish, palette = "Purples 2", n.breaks = 3) +
    small_axis(label = "", fontsize = font_size_small, arrow_length = 5) +
    labs(title = "", color = "Expr.") +
    facet_wrap(vars(fake_condition), ncol = 1, labeller = as_labeller(c("fake_ctrl" = "Ctrl.", "fake_trt" = "Trt."))) +
    guides(color = guide_colorbar(barwidth = unit(1, "mm"))) 

# UMAP coloured by LEMUR's predicted DE (the "DE" assay) for the selected
# gene. The diverging colour scale is limited to [-2.5, 2.5]
# (scale_color_de_gradient presumably comes from util.R).
de_pred_pl <- as_tibble(colData(lemur_fit)) %>%
  mutate(umap) %>%
  mutate(de_pred = assay(lemur_fit, "DE")[sel_gene,]) %>%
  ggplot(aes(x = umap[,1], y = umap[,2])) +
    ggrastr::rasterize(geom_point(aes(color = de_pred), size = 0.05, stroke = 0), dpi = 600) +
    # scale_colour_gradient2_rev(limits = c(-2.5, 2.5), oob = scales::oob_squish, breaks = c(-2, 0, 2), mid = "lightgrey") +
    scale_color_de_gradient(limits = c(-2.5, 2.5), breaks = c(-2, 0, 2), mid_width = 0.2) +
    small_axis(label = "UMAP", fontsize = font_size_small) +
    labs(#title = "(G) Predicted DE + Neighborhood", 
         color = "$\\Delta$") +
    guides(color = guide_colorbar(barwidth = unit(1, "mm"))) 

# UMAP split into cells inside vs. outside the selected gene's inferred
# neighborhood, coloured by the simulated ground-truth DE status.
de_nei_pl <- as_tibble(colData(lemur_fit)) %>%
  mutate(umap) %>%
  mutate(is_de = lemur_fit$rowData$is_de_cell[[sel_gene]]) %>%
  # TRUE when the cell's name is listed in the gene's neighborhood.
  mutate(inside = colnames(lemur_fit) %in% nei$neighborhood[[sel_gene]]) %>%
  ggplot(aes(x = umap[,1], y = umap[,2])) +
    ggrastr::rasterize(geom_point(aes(color = is_de), size = 0.05, stroke = 0), dpi = 600) +
    small_axis(label = "", fontsize = font_size_small, arrow_length = 5) +
    guides(color = "none") +
    labs(title = " ") +
    # The stray empty argument (", ,") that previously filled `nrow`
    # positionally has been removed; nrow still takes its default.
    facet_wrap(vars(inside), ncol = 1, labeller = as_labeller(c("TRUE" = "Inside", "FALSE" = "Outside")))

# Display the four neighborhood example panels.
de_simulated_pl

de_pred_pl

de_expr_pl

de_nei_pl

# Per-gene neighborhood overlap counts (TP / FP / FN), DE cell counts,
# neighborhood sizes and adjusted p-values from the Kang DE benchmark.
prec_recall_df <- read_tsv("../benchmark/output/differential_expression-kang-recall_precision.tsv.gz") 
## Rows: 2200 Columns: 9
## ── Column specification ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
## Delimiter: "\t"
## chr (3): data, vals, name
## dbl (6): adj_pval, TP, FP, FN, de_n_cells, nei_size
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Panel (H): for every simulated gene, recall (TP / cells with simulated DE)
# against precision (TP / neighborhood size) of the detected neighborhood.
# Genes with adj_pval < 0.1 are black, all others light grey, and the example
# gene shown in panels F/G (kang, ctrl, `sel_gene`) is red.
# arrange() sorts the categories as "not signif" < "signif" < "special", so the
# example gene is drawn last (on top).
prec_recall_plot <- prec_recall_df %>%
  mutate(recall = TP / de_n_cells,
         precision = TP / nei_size) %>%
  mutate(highlighted = case_when(
    data == "kang" & vals == "ctrl" & name == glue::glue("simulated_gene-{sel_gene}") ~ "special",
    adj_pval < 0.1 ~ "signif",
    TRUE ~ "not signif"
  )) %>%
  arrange(highlighted) %>%
  ggplot(aes(x = recall, y = precision)) +
    ggrastr::rasterize(geom_point(aes(color = highlighted, size = highlighted)), dpi = 300) +
    scale_x_continuous(expand = expansion(0), breaks = c(0, 0.25, 0.5, 0.75, 1), labels = c("0", "0.25", "0.5", "0.75", "1")) +
    scale_y_continuous(expand = expansion(0), breaks = c(0, 0.5, 1), labels = c("0", "0.5", "1")) +
    scale_color_manual(values = c("special" = "red", "signif" = "black", "not signif" = "lightgrey"), 
                       labels = c("special" = "Example gene\nfrom F,G", "signif" = "signif. genes", "not signif" = "not signif. genes")) +
    scale_size_manual(values = c("special" = 0.7, "signif" = 0.4, "not signif" = 0.1)) +
    # Legend keys: fully opaque and larger than the data points.
    guides(color = guide_legend(override.aes = list(alpha = 1, size = 1.5)), size = "none") +
    coord_cartesian(clip = "off") +
    labs(title = "(H) Neighborhood Precision and Recall",
         subtitle = "Overlap of neighborhood and simulated ground truth", 
         color = "", x = "Recall", y = "Precision") +
    theme(plot.title.position = "plot", legend.position = "bottom")

prec_recall_plot

Differential expression control

# Benchmark results per dataset (`data`), condition (`vals`), method and
# nominal FDR (columns TPR, TP, FDR, nominal_fdr).
# `method` strings look like "<method>_<framework>[_<extra settings>]":
# whatever follows "lemur_edgeR_" is stored in `extra_settings` (NA otherwise)
# before `method` is split into `method` and `de_framework`.
de_power_fdr <- read_tsv("../benchmark/output/differential_expression_fdr_power-kmeans.tsv.gz") %>%
  mutate(extra_settings = ifelse(str_starts(method, "lemur_edgeR_"), str_remove(method, "lemur_edgeR_"), NA)) %>%
  # The pieces beyond the first two are already kept in `extra_settings`;
  # drop them explicitly rather than via separate()'s warn-and-drop default.
  separate(method, into = c("method", "de_framework"), extra = "drop")
## Rows: 4851 Columns: 7
## ── Column specification ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
## Delimiter: "\t"
## chr (3): data, vals, method
## dbl (4): TPR, TP, FDR, nominal_fdr
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
## Warning: Expected 2 pieces. Additional pieces discarded in 2541 rows [64, 65,
## 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, ...].
# Which dataset/condition combinations are in the benchmark?
distinct(de_power_fdr, data, vals)
## # A tibble: 11 × 2
##    data          vals   
##    <chr>         <chr>  
##  1 angelidis     24m    
##  2 angelidis     3m     
##  3 goldfarbmuren heavy  
##  4 goldfarbmuren never  
##  5 hrvatin       0h     
##  6 hrvatin       1h     
##  7 hrvatin       4h     
##  8 sathyamurthy  form   
##  9 kang          ctrl   
## 10 kang          stim   
## 11 reyfman       Control
# Panel (I): FDR calibration of LEMUR (edgeR backend, default settings).
# One light-grey line per dataset/condition plus the across-dataset mean
# (NAs ignored) in dark blue; geom_abline() draws the identity line.
fdr_control_plot <- de_power_fdr %>%
  filter(de_framework == "edgeR", method == "lemur", is.na(extra_settings)) %>%
  ggplot(aes(x = nominal_fdr, y = FDR, group = paste0(data, method, vals))) +
    geom_line(data = \(df) filter(df, method == "lemur"),
              color = "lightgrey", linewidth = 1) +
    geom_line(data = \(df) df %>%
                filter(method == "lemur") %>%
                summarize(FDR = mean(FDR, na.rm = TRUE), .by = nominal_fdr),
              aes(group = 1), color = "darkblue", linewidth = 1.3) +
    geom_abline() +
    coord_fixed(ylim = c(0, 0.3), xlim = c(0, 0.2)) +
    scale_x_continuous(expand = expansion(0), breaks = c(0, 0.1, 0.2),
                       labels = c("0", "0.1", "0.2")) +
    scale_y_continuous(expand = expansion(0)) +
    labs(title = "(I) FDR control", 
         x = "Nominal FDR", y = "Observed FDR",
         subtitle = "Eleven datasets (\\textcolor{blue!70!black}{mean})") +
    theme(plot.title.position = "plot")

print(fdr_control_plot)

# Panel (J): true positive rate versus nominal FDR for every method (edgeR
# backend, default settings), one facet per dataset/condition; LEMUR in red,
# the other methods in grey.
power_plot <- de_power_fdr %>%
  filter(de_framework == "edgeR" & is.na(extra_settings)) %>%
  mutate(data = dataset_labels[data]) %>%
  mutate(data_vals = glue::glue("{data}\n{vals}")) %>%
  mutate(is_lemur = method == "lemur") %>%
  ggplot(aes(x = nominal_fdr, y = TPR)) +
    geom_hline(yintercept = 0, linewidth = 0.3) +
    # One line per method within each facet.
    geom_line(aes(color = is_lemur, group = method), show.legend = FALSE) +
    scale_color_manual(values = c("TRUE" = "red", "FALSE" = "grey")) +
    ggh4x::facet_wrap2(vars(data_vals), nrow = 3, strip = ggh4x::strip_vanilla(clip = "off")) +
    coord_cartesian(ylim = c(0, 0.75), xlim = c(0, 0.2)) +
    scale_x_continuous(expand = expansion(add = c(0, 0.02)), breaks = c(0, 0.1, 0.2), labels = c("0", "0.1", "0.2")) +
    scale_y_continuous(expand = expansion(0)) +
    labs(title = "(J) DE-test Power", 
         x = "Nominal FDR", y = "True positive rate (TPR)",
         subtitle = "Comparison of \\textcolor{red}{LEMUR} against four other \\textcolor{gray!90}{DE methods}") +
    # NOTE(review): margin() combines its four sides with c() before applying
    # its own `unit` argument (default "pt"); confirm that this bottom margin
    # really renders as -1 cm and not as -1 pt.
    theme(plot.title.position = "plot", strip.text = element_text(margin = margin(b = unit(-1, "cm"))))

power_plot

# Adds display labels to `de_power_fdr`:
#   es                    copy of `extra_settings` (NA = default settings)
#   extra_settings_pretty LaTeX / plain-text description of the setting
#   method                factor with display names (Global, Cell type, ...)
#   method_label          "<method> (<framework>)", plus "\nwith <setting>"
#                         when a non-default setting was used
#   data_vals             "<dataset>\n<condition>"
de_power_fdr_with_label <- de_power_fdr %>%
  mutate(es = extra_settings) %>%
  # Test fraction in percent for "testfrac_<fraction>" settings, NA elsewhere.
  # case_when() evaluates every right-hand side on all rows, so converting only
  # the matching rows avoids the "NAs introduced by coercion" warning.
  mutate(testfrac_pct = as.numeric(if_else(str_starts(es, "testfrac_"),
                                           str_remove(es, "testfrac_"),
                                           NA_character_)) * 100) %>%
  mutate(extra_settings_pretty = case_when(
    is.na(es) ~ "",
    str_starts(es, "nemb_") ~ glue("$\\#\\text{{dimensions}} = {str_remove(es, 'nemb_')}$"),
    str_starts(es, "testfrac_") ~ glue("$\\text{{test fraction}} = {testfrac_pct}\\%$"),
    es == "sel_contrast" ~ "contrast selection",
    es == "dir_contrast" ~ "contrast directions",
    es == "skip_al" ~ "$S=I$ (rigid)",
    es == "skip_multCondPCA" ~ "parametric Harmony",
    es == "sf_method_ratio" ~ "ratio size factor estimation",
    es == "count_split" ~ "count splitting",
    es == "notesttraining" ~ "no test/train split",
    TRUE ~ es
  )) %>%
  dplyr::select(-testfrac_pct) %>%
  mutate(method = factor(method, levels = c("global", "celltype", "cluster", "miloDE", "lemur"),
                         labels = c("Global", "Cell type", "Cluster", "miloDE", "LEMUR"))) %>%
  mutate(method_label = case_when(
    is.na(extra_settings) ~ glue("{method} ({de_framework})"),
    TRUE ~ glue("{method} ({de_framework})\nwith {extra_settings_pretty}")
  )) %>%
  mutate(data_vals = glue::glue("{data}\n{vals}"))
## Warning: There was 1 warning in `mutate()`.
## ℹ In argument: `extra_settings_pretty = case_when(...)`.
## Caused by warning:
## ! NAs introduced by coercion
# Supplementary FDR calibration for every method variant, as two figures split
# on `method == "LEMUR"` (group_map() returns the FALSE group first, then TRUE).
# Each facet is one `method_label`: light-grey lines per dataset/condition and
# the mean observed FDR (NAs ignored) in dark blue.
fdr_control_plots <- de_power_fdr_with_label %>%
  group_by(method == "LEMUR") %>%
  group_map(function(grp_df, grp_key) {
    per_dataset <- aes(x = nominal_fdr, y = FDR,
                       group = paste0(data, vals, method_label))
    mean_per_label <- \(df) summarize(df, FDR = mean(FDR, na.rm = TRUE),
                                      .by = c(nominal_fdr, method_label))
    ggplot(grp_df, per_dataset) +
      geom_line(color = "lightgrey", linewidth = 1) +
      geom_line(data = mean_per_label, aes(group = 1),
                color = "darkblue", linewidth = 1.3) +
      geom_abline() +
      coord_fixed(ylim = c(0, 0.3), xlim = c(0, 0.2)) +
      scale_x_continuous(expand = expansion(0), breaks = c(0, 0.1, 0.2),
                         labels = c("0", "0.1", "0.2")) +
      scale_y_continuous(expand = expansion(0)) +
      ggh4x::facet_wrap2(vars(method_label), ncol = 7,
                         strip = ggh4x::strip_vanilla(clip = "off")) +
      labs(x = "Nominal FDR", y = "Observed FDR") +
      theme(plot.title.position = "plot", panel.spacing.x = unit(5, "mm"))
  })

# Relative power at nominal FDR = 10%: each TPR divided by the mean TPR of all
# rows on the same dataset/condition. Missing FDR values count as 0; the
# summary point is red when a method variant's mean observed FDR is below the
# nominal level. Method labels are ordered by their mean relative TPR.
all_methods_power <- de_power_fdr_with_label %>%
  filter(nominal_fdr == 0.1) %>%
  mutate(FDR = replace(FDR, is.na(FDR), 0)) %>%
  mutate(controls_fdr = mean(FDR) < nominal_fdr,
         .by = c(method, de_framework, extra_settings)) %>%
  mutate(rel_TPR = TPR / mean(TPR), .by = data_vals) %>%
  mutate(method_label = fct_reorder(str_replace(method_label, "\n", " "),
                                    rel_TPR, .fun = mean)) %>%
  ggplot(aes(x = rel_TPR, y = method_label)) +
    geom_vline(xintercept = 1, linewidth = 0.3, color = "black") +
    ggbeeswarm::geom_quasirandom(color = "grey", width = 0.2, size = 0.5) +
    stat_summary(aes(color = controls_fdr), geom = "pointrange",
                 fun.data = mean_se, size = 0.4) +
    ggh4x::facet_grid2(rows = vars(method), scales = "free_y", space = "free_y") +
    scale_color_manual(values = c("TRUE" = "red", "FALSE" = "lightgrey")) +
    labs(color = "FDR control on average", y = "", x = "Relative TPR") +
    theme(strip.text = element_blank(), panel.spacing.y = unit(3, "mm"))
print(fdr_control_plots)
## [[1]]
## Warning: Removed 146 rows containing missing values (`geom_line()`).

## 
## [[2]]
## Warning: Removed 13 rows containing missing values (`geom_line()`).

print(all_methods_power)
## Orientation inferred to be along y-axis; override with
## `position_quasirandom(orientation = 'x')`

# Supplementary figure (170 x 230 mm, written to
# ../plots/suppl_all_fdr_power.pdf): (A) FDR calibration of all method
# variants, stacking `fdr_control_plots` with the second plot twice as tall
# as the first; (B) relative power at FDR = 10% (`all_methods_power`).
# Positions and sizes are in mm; `latex_support = TRUE` is presumably what
# lets the `$...$` labels render as LaTeX -- confirm in plot_assemble().
plot_assemble(
  # Panel (A): title text, then the stacked calibration plots.
  add_text("(A) FDR control for all methods", x = 2.7, y = 2, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_plot(cowplot::plot_grid(plotlist = fdr_control_plots, ncol = 1, rel_heights = c(1,2)), x = 0, y = 5, width = 170, height = 120),
  # Panel (B): title text, then the relative-power plot.
  add_text("(B) Average power for all methods at $\\textrm{FDR} = 10\\%$", x = 2.7, y = 125, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_plot(all_methods_power, x = 0, y = 130, width = 130, height = 100),

  width = 170, height = 230, units = "mm", show_grid_lines = FALSE,
  latex_support = TRUE, filename = "../plots/suppl_all_fdr_power.pdf"
)
## Warning: Removed 146 rows containing missing values (`geom_line()`).
## Warning: Removed 13 rows containing missing values (`geom_line()`).
## gg[gg1]
## gg[gg2]
## gg[gg3]
## Orientation inferred to be along y-axis; override with `position_quasirandom(orientation = 'x')`
## gg[gg4]

For which datasets is the FDR not controlled (observed FDR above 0.1 at a nominal FDR of 0.01)?

# Dataset/condition combinations where a default-settings method (`es` is NA)
# reports an observed FDR above 0.1 at nominal FDR 0.01; missing FDR values
# count as 0. `n` is the number of such rows, largest first.
de_power_fdr_with_label %>%
  filter(nominal_fdr == 0.01, is.na(es)) %>%
  mutate(FDR = replace(FDR, is.na(FDR), 0)) %>%
  filter(FDR > 0.1) %>%
  dplyr::count(data, vals, sort = TRUE)
## # A tibble: 8 × 3
##   data          vals      n
##   <chr>         <chr> <int>
## 1 hrvatin       4h        4
## 2 angelidis     24m       2
## 3 angelidis     3m        2
## 4 kang          ctrl      2
## 5 kang          stim      2
## 6 goldfarbmuren never     1
## 7 hrvatin       0h        1
## 8 sathyamurthy  form      1
power_per_desize <- readr::read_tsv("../benchmark/output/differential_expression_fdr_power-kmeans-stratified.tsv.gz")
## Rows: 35191 Columns: 5
## ── Column specification ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
## Delimiter: "\t"
## chr (3): method, data_val, name
## dbl (1): de_size
## lgl (1): signif
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Panel (K): histogram (log x-axis, >= 50 cells) of the number of cells with a
# simulated expression change per gene, one row per method (edgeR backend,
# default settings); genes called significant are dark blue.
power_per_desize_plot <- power_per_desize %>%
  mutate(extra_settings = ifelse(str_starts(method, "lemur_edgeR_"), str_remove(method, "lemur_edgeR_"), NA)) %>%
  # Pieces beyond "<method>_<framework>" are already kept in `extra_settings`.
  separate(method, into = c("method", "de_framework"), extra = "drop") %>%
  filter(de_framework == "edgeR" & is.na(extra_settings)) %>%
  mutate(method = factor(method, levels = c("lemur", "global",  "cluster", "celltype", "miloDE"),
                         labels = c("LEMUR", "Global", "Cluster*", "Cell type*", "miloDE*"))) %>%
  ggplot(aes(x = de_size)) + 
    geom_hline(yintercept = 0, linewidth = 0.3) +
    geom_histogram(aes(fill = signif), bins = 30, show.legend = FALSE) +
    ggh4x::facet_grid2(rows = vars(method), strip = ggh4x::strip_vanilla(clip = "off")) +
    scale_fill_manual(values = c("TRUE" = "darkblue", "FALSE" = "lightgrey")) +
    scale_x_log10(limits = c(50, NA), labels = scales::label_comma()) +
    scale_y_continuous(expand = expansion(add = c(0, 10)), breaks = c(0, 100, 200)) +
    labs(title = "(K) Power by no. cells with expr. change",
         subtitle = "Fraction of simulated genes (\\textcolor{gray!60!black}{grey}) with $\\textrm{FDR} < 0.1$ (\\textcolor{blue!60!black}{blue})",
         x = "No. cells with expr. change (log scale)", y = "No. genes") +
    theme(plot.title.position = "plot")
## Warning: Expected 2 pieces. Additional pieces discarded in 24200 rows [201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219,
## 220, ...].
print(power_per_desize_plot)
## Warning: Removed 50 rows containing non-finite values (`stat_bin()`).
## Warning: Removed 10 rows containing missing values (`geom_bar()`).

Variance Explained

var_expl <- readr::read_tsv("../benchmark/output/variance_explained.tsv.gz")
## Rows: 3120 Columns: 12
## ── Column specification ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
## Delimiter: "\t"
## chr (3): method, subset, dataset
## dbl (9): dimensions, var_expl, user.self, sys.self, elapsed, user.child, sys...
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Percent variance explained versus number of latent dimensions for LEMUR and
# PCA (subset == "all"). Columns are datasets ordered by cell count; rows are
# the gene set ("All Genes" vs. "500 HVG"). A (0 dimensions, 0 variance) row is
# appended per method/subset/dataset/gene-set so every curve starts at 0.
var_expl_overall_pl <- var_expl %>%
  bind_rows(
    var_expl %>%
      distinct(method, subset, dataset, n_genes, n_cells) %>%
      mutate(dimensions = 0, var_expl = 0)
  ) %>%
  filter(subset == "all", !is.na(var_expl)) %>%
  mutate(
    method = factor(method, levels = c("lemur", "pca"), labels = c("LEMUR", "PCA")),
    dataset = fct_reorder(dataset, n_cells),
    is_row_subset = n_genes == 500
  ) %>%
  ggplot(aes(x = dimensions, y = var_expl * 100)) +
    geom_hline(yintercept = 0, linewidth = 0.3) +
    geom_line(aes(color = method, linetype = method)) +
    ggh4x::facet_grid2(
      rows = vars(is_row_subset), cols = vars(dataset),
      labeller = ggplot2::labeller(
        is_row_subset = as_labeller(c("TRUE" = "500 HVG", "FALSE" = "All Genes")),
        dataset = as_labeller(dataset_labels)
      ),
      strip = ggh4x::strip_vanilla(clip = "off")
    ) +
    scale_x_continuous(breaks = c(0, 100, 200), expand = expansion(mult = c(0, 0.1))) +
    scale_y_continuous(expand = expansion(mult = c(0)), breaks = c(0, 25, 50, 75, 100),
                       limits = c(0, 100)) +
    labs(y = "Percent variance explained", x = "Latent dimensions",
         color = "", linetype = "", title = "(A) Variance explained") +
    theme(legend.position = "bottom",
          panel.grid.major.y = element_line("lightgrey", linewidth = 0.2),
          panel.spacing.y = unit(5, "mm"), panel.spacing.x = unit(3, "mm"))

# Computation time versus number of latent dimensions (> 3) for LEMUR, PCA and
# the two LEMUR alignment variants, as a list of two plots from group_map():
# [[1]] is_row_subset == FALSE ("All Genes"), [[2]] TRUE ("500 HVG").
# The 500-HVG plot makes its title and column strips transparent, presumably
# because it is stacked directly below the all-genes plot in the final figure.
var_duration_pl1 <- var_expl %>%
  filter(subset == "all" | is.na(subset)) %>%
  filter(dimensions > 3) %>%
  mutate(method = factor(method, levels = c("lemur", "pca", "lemur_landmark", "lemur_harmony"), 
                         labels = c("LEMUR", "PCA", "LEMUR + Landmark Alignment", "LEMUR + Harmony Alignment"))) %>%
  mutate(is_row_subset = n_genes == 500) %>%
  mutate(original_n_genes = max(n_genes), .by = dataset) %>%
  mutate(dataset = fct_reorder(dataset, n_cells)) %>%
  group_by(is_row_subset) %>%
  group_map(\(data, key){
    is_subset <- key[[1]][1]
    # group_map() drops the grouping column; restore it for the facet rows.
    data$is_row_subset <- is_subset
    # Plot minutes (`elapsed` is presumably seconds from system.time()).
    data$elapsed <- data$elapsed / 60
    ggplot(data, aes(x = dimensions, y = elapsed)) +
      geom_hline(yintercept = 0, linewidth = 0.3) +
      geom_line(aes(color = method, linetype = method)) +
      # Column strips show the dataset, its cell count C and full gene count G.
      ggh4x::facet_grid2(vars(is_row_subset), vars(dataset, n_cells, original_n_genes), 
                 labeller = ggplot2::labeller(is_row_subset = as_labeller(c("TRUE" = "500 HVG", "FALSE" = "All Genes")), 
                                              dataset = as_labeller(dataset_labels),
                                              n_cells = as_labeller(\(x) glue("\\tiny{{$C={scales::comma_format()(as.numeric(x))}$}}")),
                                              original_n_genes = as_labeller(\(x) glue("\\tiny{{$G={scales::comma_format()(as.numeric(x))}$}}"))),
                 strip = ggh4x::strip_vanilla(clip = "off"), scales = "free_y") +
      scale_x_continuous(breaks = c(0, 100, 200), expand = expansion(mult = c(0, 0.1))) +
      scale_y_continuous(expand = expansion(mult = c(0, 0.1))) +
      scale_color_manual(values = scales::hue_pal()(4) |> magrittr::set_names(c("LEMUR", "LEMUR + Landmark Alignment", "PCA", "LEMUR + Harmony Alignment"))) +
      scale_linetype_manual(values = c("LEMUR" = "solid", "LEMUR + Landmark Alignment"  = "22", "PCA"  = "22", "LEMUR + Harmony Alignment"  = "solid")) +
      coord_cartesian(ylim = c(0, if(is_subset) 4 else 11)) +
      labs(y = "Elapsed time [min.]",
           x = "Latent dimensions", color = "", linetype = "",
           title = "(B) Computation Time") +
      theme(legend.position = "bottom", 
            strip.text.x = element_text(color = if(is_subset) "#00000000" else "black", margin = margin(t = 0.3, b = 0, unit = "mm")),
            plot.title = element_text(color = if(is_subset) "#00000000" else "black"),
            panel.grid.major.y = element_line("lightgrey", linewidth = 0.2),
            panel.spacing.y = unit(5, "mm"), panel.spacing.x = unit(3, "mm"))
  })

    
print(var_expl_overall_pl)

print(var_duration_pl1)
## [[1]]

## 
## [[2]]

# goldfarbmuren runs with 50 latent dimensions, excluding the 500-gene runs.
filter(var_expl, dataset == "goldfarbmuren", dimensions == 50, n_genes != 500)
## # A tibble: 8 × 12
##   dimensions method        subset var_expl user.self sys.self elapsed user.child
##        <dbl> <chr>         <chr>     <dbl>     <dbl>    <dbl>   <dbl>      <dbl>
## 1         50 pca           all       0.349      373.    0.295    35.2          0
## 2         50 pca           never     0.353      373.    0.295    35.2          0
## 3         50 pca           heavy     0.344      373.    0.295    35.2          0
## 4         50 lemur         all       0.355      791.   11.1     103.           0
## 5         50 lemur         never     0.360      791.   11.1     103.           0
## 6         50 lemur         heavy     0.350      791.   11.1     103.           0
## 7         50 lemur_landma… <NA>     NA          794.   12.6     105.           0
## 8         50 lemur_harmony <NA>     NA          902.   31.4     198.           0
## # ℹ 4 more variables: sys.child <dbl>, dataset <chr>, n_genes <dbl>,
## #   n_cells <dbl>
# mouse_gastrulation runs with 200 latent dimensions, excluding the 500-gene runs.
filter(var_expl, dataset == "mouse_gastrulation", dimensions == 200, n_genes != 500)
## # A tibble: 8 × 12
##   dimensions method        subset var_expl user.self sys.self elapsed user.child
##        <dbl> <chr>         <chr>     <dbl>     <dbl>    <dbl>   <dbl>      <dbl>
## 1        200 pca           all       0.210    19669.     21.5   1303.          0
## 2        200 pca           wt        0.258    19669.     21.5   1303.          0
## 3        200 pca           chime…    0.196    19669.     21.5   1303.          0
## 4        200 lemur         all       0.217    44023.     87.0   2995.          0
## 5        200 lemur         wt        0.282    44023.     87.0   2995.          0
## 6        200 lemur         chime…    0.199    44023.     87.0   2995.          0
## 7        200 lemur_landma… <NA>     NA        44059.    114.    3015.          0
## 8        200 lemur_harmony <NA>     NA        44673.    262.    3496.          0
## # ℹ 4 more variables: sys.child <dbl>, dataset <chr>, n_genes <dbl>,
## #   n_cells <dbl>
# Supplementary figure (170 x 125 mm, written to
# ../plots/suppl_variance_explained.pdf): variance explained on top, then the
# two computation-time plots. `var_duration_pl1[[2]]` (500 HVG) is placed at
# y = 85, overlapping the lower 10 mm of `var_duration_pl1[[1]]` (all genes,
# y = 60..95). Only the second time plot keeps the colour/linetype legend.
plot_assemble(
  add_plot(var_expl_overall_pl, x = 0, y = 0, width = 170, height = 60),
  # All-genes time plot; its legend is dropped in favour of the one below.
  add_plot(var_duration_pl1[[1]] + guides(color =  "none", linetype = "none"), x = 0, y = 60, width = 170, height = 35),
  add_plot(var_duration_pl1[[2]], x = 0, y = 85, width = 170, height = 40),

  width = 170, height = 125, units = "mm", show_grid_lines = FALSE,
  latex_support = TRUE, filename = "../plots/suppl_variance_explained.pdf"
)
## gg[gg1]
## gg[gg2]
## gg[gg3]

Assemble

# Main performance/validation figure (170 x 205 mm, written to
# ../plots/performance_validation.pdf). Positions and sizes are in mm and
# laid out in four rows (y = 0, 50, 100, 145). Some panels are placed without
# their legend, and the legend is added separately via cowplot::get_legend().
plot_assemble(
  # Row 1 (y = 0..50): Kang overview and (B) integration performance.
  add_plot(kang_plot, x = 0, y = 0, width = 110, height = 50),
  add_text("(B) Integration Performance", x = 111.5, y = 1.5, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_plot(int_bio_pl, x = 110, y = 5, width = 60, height = 45),

  # Row 2 (y = 50..100): cell-type plots, shared scatter legend, prediction performance.
  add_plot(kang_plot_cell_type, x = 0, y = 50, width = 30, height = 37),
  add_plot(ct_scatter_plot + guides(color = "none"), x = 27, y = 50, width = 83, height = 40),
  add_plot(cowplot::get_legend(ct_scatter_plot) , x = 10, y = 88, width = 75, height = 10),
  add_plot(perf_pred_plot, x = 110, y = 50, width = 60, height = 50),
  # Row 3 (y = 100..145): example gene (F, G) and neighborhood precision/recall (H).
  add_text("(F) Simulated DE", x = 2.7, y = 101.4, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_plot(de_simulated_pl, x = 0, y = 100, width = 30, height = 45),
  add_plot(de_expr_pl, x = 25, y = 100, width = 30, height = 45),
  add_text("(G) Predicted DE with neighborhood", x = 54, y = 101.4, fontsize = font_size, vjust = 1, fontface = "bold"),
  add_plot(de_pred_pl, x = 53, y = 100, width = 40, height = 45),
  add_plot(de_nei_pl, x = 83, y = 100, width = 35, height = 45),
  add_plot(prec_recall_plot + guides(color = "none"), x = 110, y = 100, width = 60, height = 42),
  add_plot(cowplot::get_legend(prec_recall_plot), x = 115, y = 140, width = 50, height = 5),
  # Row 4 (y = 145..205): FDR control (I), power (J), power by DE size (K).
  add_plot(fdr_control_plot, x = 0, y = 145, width = 40, height = 60),
  add_plot(power_plot, x = 40, y = 145, width = 68, height = 60),
  add_plot(power_per_desize_plot, x = 110, y = 145, width = 60, height = 60),
  # Canvas size and output settings.
  width = 170, height = 205, units = "mm", show_grid_lines = FALSE,
  latex_support = TRUE, filename = "../plots/performance_validation.pdf"
)
## gg[gg1]
## gg[gg2]
## gg[gg3]
## gg[gg4]
## gg[gg5]
## gg[gg6]
## Warning: Removed 12 rows containing non-finite values (`stat_summary()`).
## Orientation inferred to be along y-axis; override with
## `position_quasirandom(orientation = 'x')`
## Warning: Removed 12 rows containing missing values (`position_quasirandom()`).
## gg[gg7]
## gg[gg8]
## gg[gg9]
## gg[gg10]
## gg[gg11]
## gg[gg12]
## gg[gg13]
## gg[gg14]
## gg[gg15]
## gg[gg16]
## gg[gg17]
## Warning: Removed 50 rows containing non-finite values (`stat_bin()`).
## Warning: Removed 10 rows containing missing values (`geom_bar()`).
## gg[gg18]
##  [1] TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE
## [16] TRUE TRUE

Session Info

utils::sessionInfo()
## R version 4.3.2 (2023-10-31)
## Platform: x86_64-apple-darwin20 (64-bit)
## Running under: macOS Big Sur 11.7.10
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.3-x86_64/Resources/lib/libRblas.0.dylib 
## LAPACK: /Library/Frameworks/R.framework/Versions/4.3-x86_64/Resources/lib/libRlapack.dylib;  LAPACK version 3.11.0
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## time zone: Europe/Berlin
## tzcode source: internal
## 
## attached base packages:
## [1] stats4    stats     graphics  grDevices datasets  utils     methods  
## [8] base     
## 
## other attached packages:
##  [1] lemur_1.1.5                 glue_1.6.2                 
##  [3] lubridate_1.9.3             forcats_1.0.0              
##  [5] stringr_1.5.0               dplyr_1.1.3                
##  [7] purrr_1.0.2                 readr_2.1.4                
##  [9] tidyr_1.3.0                 tibble_3.2.1               
## [11] ggplot2_3.4.3               tidyverse_2.0.0            
## [13] SingleCellExperiment_1.22.0 SummarizedExperiment_1.30.2
## [15] Biobase_2.60.0              GenomicRanges_1.52.1       
## [17] GenomeInfoDb_1.36.4         IRanges_2.34.1             
## [19] S4Vectors_0.38.2            BiocGenerics_0.46.0        
## [21] MatrixGenerics_1.12.3       matrixStats_1.0.0          
## 
## loaded via a namespace (and not attached):
##  [1] bitops_1.0-7             rlang_1.1.1              magrittr_2.0.3          
##  [4] RcppAnnoy_0.0.21         compiler_4.3.2           png_0.1-8               
##  [7] vctrs_0.6.3              pkgconfig_2.0.3          crayon_1.5.2            
## [10] fastmap_1.1.1            XVector_0.40.0           labeling_0.4.3          
## [13] utf8_1.2.3               Rsamtools_2.16.0         rmarkdown_2.25          
## [16] tzdb_0.4.0               ggbeeswarm_0.7.2         strawr_0.0.91           
## [19] bit_4.0.5                xfun_0.40                zlibbioc_1.46.0         
## [22] cachem_1.0.8             jsonlite_1.8.7           DelayedArray_0.26.7     
## [25] BiocParallel_1.34.2      irlba_2.3.5.1            parallel_4.3.2          
## [28] R6_2.5.1                 plyranges_1.20.0         bslib_0.5.1             
## [31] stringi_1.7.12           RColorBrewer_1.1-3       rtracklayer_1.60.1      
## [34] jquerylib_0.1.4          Rcpp_1.0.11              knitr_1.44              
## [37] clisymbols_1.2.0         filehash_2.4-5           Matrix_1.6-1.1          
## [40] timechange_0.2.0         tidyselect_1.2.0         rstudioapi_0.15.0       
## [43] abind_1.4-5              yaml_2.3.7               codetools_0.2-19        
## [46] curl_5.1.0               lattice_0.21-9           withr_2.5.1             
## [49] ggrastr_1.0.2            evaluate_0.22            gridGraphics_0.5-1      
## [52] isoband_0.2.7            Biostrings_2.68.1        pillar_1.9.0            
## [55] BiocManager_1.30.22      renv_1.0.3               generics_0.1.3          
## [58] vroom_1.6.4              RCurl_1.98-1.12          tidylog_1.0.2           
## [61] plotgardener_1.6.4       hms_1.1.3                munsell_0.5.0           
## [64] scales_1.2.1             tikzDevice_0.12.5        tools_4.3.2             
## [67] BiocIO_1.10.0            data.table_1.14.8        GenomicAlignments_1.36.0
## [70] fs_1.6.3                 XML_3.99-0.14            Cairo_1.6-1             
## [73] cowplot_1.1.1            grid_4.3.2               colorspace_2.1-0        
## [76] GenomeInfoDbData_1.2.10  beeswarm_0.4.0           vipor_0.4.5             
## [79] restfulr_0.0.15          cli_3.6.1                fansi_1.0.5             
## [82] S4Arrays_1.0.6           uwot_0.1.16              gtable_0.3.4            
## [85] glmGamPoi_1.12.2         ggh4x_0.2.6              yulab.utils_0.1.0       
## [88] sass_0.4.7               digest_0.6.33            ggrepel_0.9.4           
## [91] ggplotify_0.1.2          rjson_0.2.21             farver_2.1.1            
## [94] memoise_2.0.1            htmltools_0.5.6.1        lifecycle_1.0.3         
## [97] MASS_7.3-60              bit64_4.0.5